'''
修改onnx模型
1、插入类别名
2、支持动态batch size
3、将原输出拆成box和score两个部分
4、插入batchnms节点
'''


import onnx
import onnx_graphsurgeon as gs
import numpy as np


NUM_BOXES = 25500
NUM_CLASSES = 1
BATCH_SIZE = -1

# Stage 1: attach the class names as model metadata and make the batch
# dimension of the model input dynamic.
labels = ['mouse']
names = {cls_id: label for cls_id, label in enumerate(labels)}

# The class list and NUM_CLASSES feed the same model (metadata vs. Split/NMS
# sizes below), so refuse to build an inconsistent graph.
if len(labels) != NUM_CLASSES:
    raise ValueError(
        f"labels has {len(labels)} entries but NUM_CLASSES is {NUM_CLASSES}")

model = onnx.load("/wjr/develop/projects/yolov5/qat.onnx")

# Update an existing "names" entry in place (e.g. one written by the exporter)
# instead of always appending: ONNX metadata_props keys should be distinct.
for meta in model.metadata_props:
    if meta.key == "names":
        break
else:
    meta = model.metadata_props.add()
    meta.key = "names"
meta.value = str(names)

graph = model.graph
# Replace the fixed batch size of the first input with the symbolic
# dimension name 'None', making the batch dimension dynamic.
graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
onnx.save(model, '/wjr/develop/projects/yolov5/qat2.onnx')


# Stage 2: reload the dynamic-batch model as a graphsurgeon graph.
graph = gs.import_onnx(onnx.load("/wjr/develop/projects/yolov5/qat2.onnx"))

# The first graph output is treated as (batch, NUM_BOXES, 4 + NUM_CLASSES).
# Split it along the last axis into the 4 box values and the class scores.
box_dims = (BATCH_SIZE, NUM_BOXES, 4)
score_dims = (BATCH_SIZE, NUM_BOXES, NUM_CLASSES)

split_sizes = np.array([4, NUM_CLASSES], dtype=np.int64)
split = gs.Constant("split", values=split_sizes)

output_box_0 = gs.Variable(name="output_box_0", shape=box_dims, dtype=np.float32)
output_scores = gs.Variable(name="output_scores", shape=score_dims, dtype=np.float32)

split_node = gs.Node(
    op="Split",
    attrs={"axis": 2},
    inputs=[graph.outputs[0], split],
    outputs=[output_box_0, output_scores],
)

# Insert a singleton axis so the boxes become (batch, NUM_BOXES, 1, 4), the
# 4-D box layout handed to the NMS node. BATCH_SIZE is -1, so Reshape infers
# the batch dimension at runtime.
boxes_4d = (BATCH_SIZE, NUM_BOXES, 1, 4)
box_shape = gs.Constant("shape", values=np.array(boxes_4d, dtype=np.int64))
output_boxes = gs.Variable(name="output_boxes", shape=boxes_4d, dtype=np.float32)

reshape_node = gs.Node(op="Reshape", inputs=[output_box_0, box_shape], outputs=[output_boxes])

# Stage 3: define the BatchedNMSDynamic_TRT node and its four outputs.
# keepTopK: number of detections finally kept per image (plugin attribute).
# topK: number of bounding boxes fed into the NMS step.
keepTopK = 100
topK = 1000

# num_detections: number of detections found for each input image.
num_detections = gs.Variable(
    name="num_detections", dtype=np.int32, shape=[BATCH_SIZE, 1])
# nmsed_boxes: box coordinates of each kept detection.
nmsed_boxes = gs.Variable(
    name="nmsed_boxes", dtype=np.float32, shape=[BATCH_SIZE, keepTopK, 4])
# nmsed_scores: score of each kept box.
nmsed_scores = gs.Variable(
    name="nmsed_scores", dtype=np.float32, shape=[BATCH_SIZE, keepTopK])
# nmsed_classes: class id of each kept box (declared as float32).
nmsed_classes = gs.Variable(
    name="nmsed_classes", dtype=np.float32, shape=[BATCH_SIZE, keepTopK])

new_outputs = [num_detections, nmsed_boxes, nmsed_scores, nmsed_classes]

# NMS plugin attributes.
attrs = {
    # The boxes are reshaped to (batch, NUM_BOXES, 1, 4): one box per
    # candidate shared by every class. With shareLocation=False the plugin
    # expects per-class boxes (class axis == numClasses), which only happens
    # to match while NUM_CLASSES == 1, so the shared layout requires True.
    "shareLocation": True,
    "backgroundLabelId": -1,   # background class id; -1 means no background class
    "numClasses": NUM_CLASSES,
    "topK": topK,              # number of bounding boxes for nms eg 1000s
    "keepTopK": keepTopK,      # bounding boxes to be kept per image eg 20
    "scoreThreshold": 0.50,    # previously tried 0.70
    "iouThreshold": 0.7,
    "isNormalized": False,     # whether input box coordinates are normalized
    "clipBoxes": False,        # clip boxes; only takes effect when isNormalized is True
    "scoreBits": 16,
    # "plugin_version" and "caffeSemantics" are intentionally not set, so the
    # plugin's own defaults apply.
}
# NOTE(review): the 4 box values are passed to the plugin unchanged; confirm
# the head's box encoding matches what BatchedNMSDynamic_TRT expects.

nms_node = gs.Node(
    # BatchedNMS_TRT cannot be used here: it does not support dynamic batch size.
    op="BatchedNMSDynamic_TRT",
    attrs=attrs,
    inputs=[output_boxes, output_scores],
    outputs=new_outputs,
)

# Stage 4: splice the new nodes into the graph and expose the NMS results as
# the graph outputs.
graph.nodes.extend([split_node, reshape_node, nms_node])
graph.outputs = new_outputs

# Remove nodes/tensors that no longer contribute to the outputs, then
# topologically sort the remaining nodes.
graph.cleanup().toposort()
exported = gs.export_onnx(graph)

# gs.export_onnx builds a new ModelProto from the graph. Depending on the
# onnx-graphsurgeon version, model-level metadata_props (such as the "names"
# entry added in stage 1) may not be carried over. Copy any missing entries
# from `model` without duplicating keys that are already present.
present_keys = {prop.key for prop in exported.metadata_props}
for prop in model.metadata_props:
    if prop.key not in present_keys:
        entry = exported.metadata_props.add()
        entry.key, entry.value = prop.key, prop.value

onnx.save(exported, "/wjr/develop/projects/yolov5/qat3.onnx")
